{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "33adcdb3",
   "metadata": {
    "id": "33adcdb3"
   },
   "source": [
    "# Video Detection Demo with PytorchWildlife"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d7ca914",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/16-OjFVQ6nopuP-gfqofYBBY00oIgbcr1?usp=sharing)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c534c504",
   "metadata": {
    "id": "c534c504"
   },
   "source": [
    "This tutorial guides you on how to use PyTorchWildlife for video detection and classification. We will go through the process of setting up the environment, defining the detection and classification models, as well as performing inference and saving the results in an annotated video.\n",
    "\n",
    "## Prerequisites\n",
    "Install PytorchWildlife running the following commands\n",
    "\n",
    "Also, make sure you have a CUDA-capable GPU if you intend to run the model on a GPU. In Google Colab, click on Runtime in the menu bar, select Change runtime type, choose GPU under Hardware accelerator, and then click Save.\n",
    "\n",
    "This notebook can also run on CPU.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "rD4wFhkognga",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "rD4wFhkognga",
    "outputId": "8e8051a5-0129-455e-e3b7-2ff9fb63ca6b"
   },
   "outputs": [],
   "source": [
    "!sudo apt-get update -y\n",
    "!sudo apt-get install python3.8 python3.8-dev python3.8-distutils libpython3.8-dev\n",
    "\n",
    "#change alternatives\n",
    "!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.8 1\n",
    "!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 2\n",
    "\n",
    "#Check that it points at the right location\n",
    "!python3 --version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "yGf75u15grdD",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "yGf75u15grdD",
    "outputId": "06806167-421c-4909-e283-edf441294700"
   },
   "outputs": [],
   "source": [
    "# install pip\n",
    "!curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py\n",
    "!python3 get-pip.py --force-reinstall\n",
    "\n",
    "#install colab's dependencies\n",
    "!python3 -m pip install ipython ipython_genutils ipykernel jupyter_console prompt_toolkit httplib2 astor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "IA0-0gF4gxVW",
   "metadata": {
    "id": "IA0-0gF4gxVW"
   },
   "outputs": [],
   "source": [
    "# link to the old google package\n",
    "!ln -s /usr/local/lib/python3.10/dist-packages/google \\\n",
    "       /usr/local/lib/python3.8/dist-packages/google\n",
    "\n",
    "!sed -i \"s/from IPython.utils import traitlets as _traitlets/import traitlets as _traitlets/\" /usr/local/lib/python3.8/dist-packages/google/colab/*.py\n",
    "!sed -i \"s/from IPython.utils import traitlets/import traitlets/\" /usr/local/lib/python3.8/dist-packages/google/colab/*.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "koCscpXyg2fd",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "id": "koCscpXyg2fd",
    "outputId": "1eda3ac5-87b0-4771-c2c1-56d794c0f7cc"
   },
   "outputs": [],
   "source": [
    "#Install PytorchWildlife\n",
    "!pip install pytorchwildlife"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "OrE7HHflglSk",
   "metadata": {
    "id": "OrE7HHflglSk"
   },
   "source": [
    "Now, you must go to Runtime and restart your session to continue.\n",
    "\n",
    "## Importing libraries\n",
    "First, let's import the necessary libraries and modules."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a28c392c",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "a28c392c",
    "outputId": "a4d18bb3-6139-4da5-9cab-40c3a94e449d"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import supervision as sv\n",
    "import torch\n",
    "from PytorchWildlife.models import detection as pw_detection\n",
    "from PytorchWildlife.models import classification as pw_classification\n",
    "from PytorchWildlife import utils as pw_utils"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "802747c2",
   "metadata": {
    "id": "802747c2"
   },
   "source": [
    "## Model Initialization\n",
    "We'll  define the device to run the models and then we will initialize the models for both video detection and classification."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "BFoXl71ihXeM",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "BFoXl71ihXeM",
    "outputId": "b2966f95-03bd-4e70-97b5-83b5589e0bb1"
   },
   "outputs": [],
   "source": [
    "!wget -O repo.zip https://github.com/microsoft/CameraTraps/archive/refs/heads/main.zip\n",
    "!unzip -q repo.zip -d ./"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd069110",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "dd069110",
    "outputId": "574a5505-2dae-4ed2-b56c-94174e0c7739"
   },
   "outputs": [],
   "source": [
    "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "if DEVICE == \"cuda\":\n",
    "    torch.cuda.set_device(0)\n",
    "\n",
    "SOURCE_VIDEO_PATH = \"/content/CameraTraps-main/demo/demo_data/videos/opossum_example.MP4\"\n",
    "TARGET_VIDEO_PATH = \"/content/CameraTraps-main/demo/demo_data/videos/opossum_processed.MP4\"\n",
    "\n",
    "# Verify the checkpoints directory exists to save the model weights if pretrained=True\n",
    "os.makedirs(os.path.join(torch.hub.get_dir(), \"checkpoints\"), exist_ok=True)\n",
    "\n",
    "# Initializing the MegaDetectorV6 model for image detection\n",
    "# Valid versions are MDV6-yolov9-c, MDV6-yolov9-e, MDV6-yolov10-c, MDV6-yolov10-e or MDV6-rtdetr-c\n",
    "detection_model = pw_detection.MegaDetectorV6(device=DEVICE, pretrained=True, version=\"MDV6-yolov10-e\")\n",
    "\n",
    "# Uncomment the following line to use MegaDetectorV5 instead of MegaDetectorV6\n",
    "#detection_model = pw_detection.MegaDetectorV5(device=DEVICE, pretrained=True, version=\"a\")\n",
    "\n",
    "# Initializing the AI4GOpossum model for image classification\n",
    "classification_model = pw_classification.AI4GOpossum(device=DEVICE, pretrained=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3cd6262a",
   "metadata": {
    "id": "3cd6262a"
   },
   "source": [
    "## Video Processing\n",
    "For each frame in the video, we'll apply detection and classification, and then annotate the frame with the results. The processed video will be saved with annotated detections and classifications."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6147a40",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "e6147a40",
    "outputId": "2c430b84-6d6b-47e4-9168-5cfb792e3d1f"
   },
   "outputs": [],
   "source": [
    "box_annotator = sv.BoxAnnotator(thickness=4)\n",
    "lab_annotator = sv.LabelAnnotator(text_color=sv.Color.BLACK, text_thickness=4, text_scale=2)\n",
    "\n",
    "def callback(frame: np.ndarray, index: int) -> np.ndarray:\n",
    "    results_det = detection_model.single_image_detection(frame, img_path=index)\n",
    "    labels = []\n",
    "    for xyxy in results_det[\"detections\"].xyxy:\n",
    "        cropped_image = sv.crop_image(image=frame, xyxy=xyxy)\n",
    "        results_clf = classification_model.single_image_classification(cropped_image)\n",
    "        labels.append(\"{} {:.2f}\".format(results_clf[\"prediction\"], results_clf[\"confidence\"]))\n",
    "    annotated_frame = lab_annotator.annotate(\n",
    "        scene=box_annotator.annotate(\n",
    "            scene=frame,\n",
    "            detections=results_det[\"detections\"],\n",
    "        ),\n",
    "        detections=results_det[\"detections\"],\n",
    "        labels=labels,\n",
    "    )\n",
    "    return annotated_frame \n",
    "\n",
    "pw_utils.process_video(source_path=SOURCE_VIDEO_PATH, target_path=TARGET_VIDEO_PATH, callback=callback, target_fps=5)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e270f0f",
   "metadata": {
    "id": "8e270f0f"
   },
   "source": [
    "### Copyright (c) Microsoft Corporation. All rights reserved.\n",
    "### Licensed under the MIT License."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
